{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Human numbers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.text import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[PosixPath('/home/ubuntu/.fastai/data/human_numbers/train.txt'),\n",
       " PosixPath('/home/ubuntu/.fastai/data/human_numbers/valid.txt')]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "path = untar_data(URLs.HUMAN_NUMBERS)\n",
    "path.ls()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def readnums(d):\n",
    "    \"Return a one-item list holding the stripped lines of `path/d` joined with ', '.\"\n",
    "    # `with` closes the file handle, which the bare `open(...).readlines()` left open\n",
    "    with open(path/d) as f: return [', '.join(o.strip() for o in f)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'one, two, three, four, five, six, seven, eight, nine, ten, eleven, twelve, thirt'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_txt = readnums('train.txt'); train_txt[0][:80]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' nine thousand nine hundred ninety eight, nine thousand nine hundred ninety nine'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "valid_txt = readnums('valid.txt'); valid_txt[0][-80:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = TextList(train_txt, path=path)\n",
    "valid = TextList(valid_txt, path=path)\n",
    "\n",
    "src = ItemLists(path=path, train=train, valid=valid).label_for_lm()\n",
    "data = src.databunch(bs=bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'xxbos one , two , three , four , five , six , seven , eight , nine , ten , eleve'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train[0].text[:80]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13017"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(data.valid_ds[0][0].data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(70, 3)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.bptt, len(data.valid_dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.905580357142857"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "13017/70/bs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "it = iter(data.valid_dl)\n",
    "x1,y1 = next(it)\n",
    "x2,y2 = next(it)\n",
    "x3,y3 = next(it)\n",
    "it.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13440"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1.numel()+x2.numel()+x3.numel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([64, 70]), torch.Size([64, 70]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1.shape,y1.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([64, 70]), torch.Size([64, 70]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x2.shape,y2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 2,  8, 10, 11, 12, 10,  9,  8,  9, 13, 18, 24, 18, 14, 15, 10, 18,  8,\n",
       "         9,  8, 18, 24, 18, 10, 18, 10,  9,  8, 18, 19, 10, 25, 19, 22, 19, 19,\n",
       "        23, 19, 10, 13, 10, 10,  8, 13,  8, 19,  9, 19, 34, 16, 10,  9,  8, 16,\n",
       "         8, 19,  9, 19, 10, 19, 10, 19, 19, 19], device='cuda:0')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1[:,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([18, 18, 26,  9,  8, 11, 31, 18, 25,  9, 10, 14, 10,  9,  8, 14, 10, 18,\n",
       "        25, 18, 10, 17, 10, 17,  8, 17, 20, 18,  9,  9, 19,  8, 10, 15, 10, 10,\n",
       "        12, 10, 12,  8, 12, 13, 19,  9, 19, 10, 23, 10,  8,  8, 15, 16, 19,  9,\n",
       "        19, 10, 23, 10, 18,  8, 18, 10, 10,  9], device='cuda:0')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y1[:,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "v = data.valid_ds.vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'xxbos eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.textify(x1[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.textify(y1[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'thousand eighteen , eight thousand nineteen , eight thousand twenty , eight thousand twenty one , eight thousand twenty two , eight thousand twenty three , eight thousand twenty four , eight thousand twenty five , eight thousand twenty six , eight thousand twenty seven , eight thousand twenty eight , eight thousand twenty nine , eight thousand thirty , eight thousand thirty one , eight thousand thirty two ,'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.textify(x2[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'eight thousand thirty three , eight thousand thirty four , eight thousand thirty five , eight thousand thirty six , eight thousand thirty seven , eight thousand thirty eight , eight thousand thirty nine , eight thousand forty , eight thousand forty one , eight thousand forty two , eight thousand forty three , eight thousand forty four , eight thousand forty five , eight thousand forty six , eight'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.textify(x3[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "', eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine ,'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.textify(x1[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'eight thousand sixty , eight thousand sixty one , eight thousand sixty two , eight thousand sixty three , eight thousand sixty four , eight thousand sixty five , eight thousand sixty six , eight thousand sixty seven , eight thousand sixty eight , eight thousand sixty nine , eight thousand seventy , eight thousand seventy one , eight thousand seventy two , eight thousand seventy three , eight thousand'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.textify(x2[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'seventy four , eight thousand seventy five , eight thousand seventy six , eight thousand seventy seven , eight thousand seventy eight , eight thousand seventy nine , eight thousand eighty , eight thousand eighty one , eight thousand eighty two , eight thousand eighty three , eight thousand eighty four , eight thousand eighty five , eight thousand eighty six , eight thousand eighty seven , eight thousand eighty'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.textify(x3[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'ninety , nine thousand nine hundred ninety one , nine thousand nine hundred ninety two , nine thousand nine hundred ninety three , nine thousand nine hundred ninety four , nine thousand nine hundred ninety five , nine thousand nine hundred ninety six , nine thousand nine hundred ninety seven , nine thousand nine hundred ninety eight , nine thousand nine hundred ninety nine xxbos eight thousand one , eight'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.textify(x3[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>  <col width='5%'>  <col width='95%'>  <tr>\n",
       "    <th>idx</th>\n",
       "    <th>text</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>0</th>\n",
       "    <th>thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine , eight thousand sixty , eight thousand sixty</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>eight , eight thousand eighty nine , eight thousand ninety , eight thousand ninety one , eight thousand ninety two , eight thousand ninety three , eight thousand ninety four , eight thousand ninety five , eight thousand ninety six , eight thousand ninety seven , eight thousand ninety eight , eight thousand ninety nine , eight thousand one hundred , eight thousand one hundred one , eight thousand one</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>thousand one hundred twenty four , eight thousand one hundred twenty five , eight thousand one hundred twenty six , eight thousand one hundred twenty seven , eight thousand one hundred twenty eight , eight thousand one hundred twenty nine , eight thousand one hundred thirty , eight thousand one hundred thirty one , eight thousand one hundred thirty two , eight thousand one hundred thirty three , eight thousand</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>three , eight thousand one hundred fifty four , eight thousand one hundred fifty five , eight thousand one hundred fifty six , eight thousand one hundred fifty seven , eight thousand one hundred fifty eight , eight thousand one hundred fifty nine , eight thousand one hundred sixty , eight thousand one hundred sixty one , eight thousand one hundred sixty two , eight thousand one hundred sixty three</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>thousand one hundred eighty three , eight thousand one hundred eighty four , eight thousand one hundred eighty five , eight thousand one hundred eighty six , eight thousand one hundred eighty seven , eight thousand one hundred eighty eight , eight thousand one hundred eighty nine , eight thousand one hundred ninety , eight thousand one hundred ninety one , eight thousand one hundred ninety two , eight thousand</th>\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "data.show_batch(ds_type=DatasetType.Valid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Single fully connected model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = src.databunch(bs=bs, bptt=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([64, 3]), torch.Size([64, 3]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x,y = data.one_batch()\n",
    "x.shape,y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "38"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nv = len(v.itos); nv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nh=64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss4(input,target):\n",
    "    \"Cross-entropy between `input` and the last column of `target`.\"\n",
    "    last_tok = target[:,-1]\n",
    "    return F.cross_entropy(input, last_tok)\n",
    "def acc4(input,target):\n",
    "    \"`accuracy` of `input` measured against the last column of `target`.\"\n",
    "    last_tok = target[:,-1]\n",
    "    return accuracy(input, last_tok)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model0(nn.Module):\n",
    "    \"Predicts from up to three input tokens, with each per-token step written out by hand.\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.i_h = nn.Embedding(nv,nh)  # green arrow\n",
    "        self.h_h = nn.Linear(nh,nh)     # brown arrow\n",
    "        self.h_o = nn.Linear(nh,nv)     # blue arrow\n",
    "        self.bn = nn.BatchNorm1d(nh)\n",
    "\n",
    "    def forward(self, x):\n",
    "        n_tok = x.shape[1]\n",
    "        # first token: embed, hidden layer, relu, batchnorm\n",
    "        emb = self.i_h(x[:,0])\n",
    "        h = self.bn(F.relu(self.h_h(emb)))\n",
    "        # second token (if present): add its embedding to the state, then the same layers\n",
    "        if n_tok>1:\n",
    "            h = self.h_h(h + self.i_h(x[:,1]))\n",
    "            h = self.bn(F.relu(h))\n",
    "        # third token (if present): identical step\n",
    "        if n_tok>2:\n",
    "            h = self.h_h(h + self.i_h(x[:,2]))\n",
    "            h = self.bn(F.relu(h))\n",
    "        return self.h_o(h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(data, Model0(), loss_func=loss4, metrics=acc4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 00:07 <p><table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>acc4</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>3.596286</th>\n",
       "    <th>3.588869</th>\n",
       "    <th>0.046645</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>3.086100</th>\n",
       "    <th>3.205763</th>\n",
       "    <th>0.274816</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>2.494411</th>\n",
       "    <th>2.749365</th>\n",
       "    <th>0.392004</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>2.144753</th>\n",
       "    <th>2.463537</th>\n",
       "    <th>0.415671</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>5</th>\n",
       "    <th>2.010915</th>\n",
       "    <th>2.352887</th>\n",
       "    <th>0.409237</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>6</th>\n",
       "    <th>1.983992</th>\n",
       "    <th>2.336967</th>\n",
       "    <th>0.408778</th>\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(6, 1e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Same thing with a loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model1(nn.Module):\n",
    "    \"Loops over every column of `x`, updating one hidden state; predicts from the final state.\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.i_h = nn.Embedding(nv,nh)  # green arrow\n",
    "        self.h_h = nn.Linear(nh,nh)     # brown arrow\n",
    "        self.h_o = nn.Linear(nh,nv)     # blue arrow\n",
    "        self.bn = nn.BatchNorm1d(nh)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        # allocate the initial state directly on the input's device (no CPU tensor + copy)\n",
    "        h = torch.zeros(x.shape[0], nh, device=x.device)\n",
    "        for i in range(x.shape[1]):\n",
    "            h = h + self.i_h(x[:,i])\n",
    "            h = self.bn(F.relu(self.h_h(h)))\n",
    "        return self.h_o(h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(data, Model1(), loss_func=loss4, metrics=acc4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 00:07 <p><table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>acc4</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>3.493525</th>\n",
       "    <th>3.420231</th>\n",
       "    <th>0.156250</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>2.987600</th>\n",
       "    <th>2.937893</th>\n",
       "    <th>0.376149</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>2.440199</th>\n",
       "    <th>2.477995</th>\n",
       "    <th>0.388787</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>2.132837</th>\n",
       "    <th>2.256569</th>\n",
       "    <th>0.391774</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>5</th>\n",
       "    <th>2.011305</th>\n",
       "    <th>2.181337</th>\n",
       "    <th>0.392923</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>6</th>\n",
       "    <th>1.985913</th>\n",
       "    <th>2.170874</th>\n",
       "    <th>0.393153</th>\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(6, 1e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi fully connected model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = src.databunch(bs=bs, bptt=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([64, 20]), torch.Size([64, 20]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x,y = data.one_batch()\n",
    "x.shape,y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model2(nn.Module):\n",
    "    \"Loop model that emits an output after every column of `x`, stacked along dim 1.\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.i_h = nn.Embedding(nv,nh)\n",
    "        self.h_h = nn.Linear(nh,nh)\n",
    "        self.h_o = nn.Linear(nh,nv)\n",
    "        self.bn = nn.BatchNorm1d(nh)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        # allocate the initial state directly on the input's device (no CPU tensor + copy)\n",
    "        h = torch.zeros(x.shape[0], nh, device=x.device)\n",
    "        res = []\n",
    "        for i in range(x.shape[1]):\n",
    "            h = h + self.i_h(x[:,i])\n",
    "            h = F.relu(self.h_h(h))\n",
    "            # batchnorm is applied only on the way to the output; `h` itself is carried forward un-normalised\n",
    "            res.append(self.h_o(self.bn(h)))\n",
    "        return torch.stack(res, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(data, Model2(), metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 00:06 <p><table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>accuracy</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>3.639285</th>\n",
       "    <th>3.709278</th>\n",
       "    <th>0.058949</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>3.551151</th>\n",
       "    <th>3.565677</th>\n",
       "    <th>0.151776</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>3.439908</th>\n",
       "    <th>3.431850</th>\n",
       "    <th>0.207741</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>3.323083</th>\n",
       "    <th>3.314237</th>\n",
       "    <th>0.283949</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>5</th>\n",
       "    <th>3.213422</th>\n",
       "    <th>3.219906</th>\n",
       "    <th>0.321662</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>6</th>\n",
       "    <th>3.119673</th>\n",
       "    <th>3.151162</th>\n",
       "    <th>0.336790</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>7</th>\n",
       "    <th>3.046645</th>\n",
       "    <th>3.106630</th>\n",
       "    <th>0.341690</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>8</th>\n",
       "    <th>2.995379</th>\n",
       "    <th>3.082552</th>\n",
       "    <th>0.346662</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>9</th>\n",
       "    <th>2.963800</th>\n",
       "    <th>3.073327</th>\n",
       "    <th>0.349645</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>10</th>\n",
       "    <th>2.947312</th>\n",
       "    <th>3.071951</th>\n",
       "    <th>0.349787</th>\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(10, 1e-4, pct_start=0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Maintain state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model3(nn.Module):\n",
    "    \"Loop model that keeps its hidden state `self.h` from one call of `forward` to the next.\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.i_h = nn.Embedding(nv,nh)\n",
    "        self.h_h = nn.Linear(nh,nh)\n",
    "        self.h_o = nn.Linear(nh,nv)\n",
    "        self.bn = nn.BatchNorm1d(nh)\n",
    "        # Created device-neutral: `forward` moves it to the input's device, so building\n",
    "        # the model no longer requires CUDA (the previous `.cuda()` call did).\n",
    "        self.h = torch.zeros(bs, nh)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        res = []\n",
    "        # follow the input's device; `.to` returns the same tensor when it is already there\n",
    "        h = self.h.to(device=x.device)\n",
    "        for i in range(x.shape[1]):\n",
    "            h = h + self.i_h(x[:,i])\n",
    "            h = F.relu(self.h_h(h))\n",
    "            res.append(self.bn(h))\n",
    "        # keep the values for the next call but drop the autograd history\n",
    "        self.h = h.detach()\n",
    "        res = torch.stack(res, dim=1)\n",
    "        res = self.h_o(res)\n",
    "        return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(data, Model3(), metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 00:11 <p><table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>accuracy</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>3.598183</th>\n",
       "    <th>3.556362</th>\n",
       "    <th>0.050710</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>3.274616</th>\n",
       "    <th>2.975699</th>\n",
       "    <th>0.401634</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>2.624206</th>\n",
       "    <th>2.036894</th>\n",
       "    <th>0.467330</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>2.022702</th>\n",
       "    <th>1.956439</th>\n",
       "    <th>0.316193</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>5</th>\n",
       "    <th>1.681813</th>\n",
       "    <th>1.934952</th>\n",
       "    <th>0.336861</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>6</th>\n",
       "    <th>1.453007</th>\n",
       "    <th>1.948201</th>\n",
       "    <th>0.351349</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>7</th>\n",
       "    <th>1.276971</th>\n",
       "    <th>2.005776</th>\n",
       "    <th>0.368679</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>8</th>\n",
       "    <th>1.138499</th>\n",
       "    <th>2.081261</th>\n",
       "    <th>0.360156</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>9</th>\n",
       "    <th>1.029217</th>\n",
       "    <th>2.145853</th>\n",
       "    <th>0.360795</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>10</th>\n",
       "    <th>0.939949</th>\n",
       "    <th>2.215388</th>\n",
       "    <th>0.372230</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>11</th>\n",
       "    <th>0.865441</th>\n",
       "    <th>2.240438</th>\n",
       "    <th>0.401491</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>12</th>\n",
       "    <th>0.805310</th>\n",
       "    <th>2.195846</th>\n",
       "    <th>0.409375</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>13</th>\n",
       "    <th>0.755035</th>\n",
       "    <th>2.324373</th>\n",
       "    <th>0.422727</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>14</th>\n",
       "    <th>0.713073</th>\n",
       "    <th>2.305542</th>\n",
       "    <th>0.449716</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>15</th>\n",
       "    <th>0.677393</th>\n",
       "    <th>2.350155</th>\n",
       "    <th>0.446449</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>16</th>\n",
       "    <th>0.645841</th>\n",
       "    <th>2.418738</th>\n",
       "    <th>0.446591</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>17</th>\n",
       "    <th>0.621809</th>\n",
       "    <th>2.456903</th>\n",
       "    <th>0.446165</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>18</th>\n",
       "    <th>0.605300</th>\n",
       "    <th>2.541699</th>\n",
       "    <th>0.443040</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>19</th>\n",
       "    <th>0.594099</th>\n",
       "    <th>2.539824</th>\n",
       "    <th>0.443040</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>20</th>\n",
       "    <th>0.587563</th>\n",
       "    <th>2.551423</th>\n",
       "    <th>0.442827</th>\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(20, 3e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## nn.RNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model4(nn.Module):\n",
    "    \"Stateful model built on `nn.RNN`: the hidden state it returns is kept (detached) and fed back in on the next call.\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.i_h = nn.Embedding(nv,nh)\n",
    "        self.rnn = nn.RNN(nh,nh, batch_first=True)\n",
    "        self.h_o = nn.Linear(nh,nv)\n",
    "        self.bn = BatchNorm1dFlat(nh)\n",
    "        # Created device-neutral: `forward` moves it to the input's device, so building\n",
    "        # the model no longer requires CUDA (the previous `.cuda()` call did).\n",
    "        self.h = torch.zeros(1, bs, nh)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        # follow the input's device; `.to` returns the same tensor when it is already there\n",
    "        h0 = self.h.to(device=x.device)\n",
    "        res,h = self.rnn(self.i_h(x), h0)\n",
    "        # keep the values for the next call but drop the autograd history\n",
    "        self.h = h.detach()\n",
    "        return self.h_o(self.bn(res))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(data, Model4(), metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 00:04 <p><table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>accuracy</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>3.451432</th>\n",
       "    <th>3.268344</th>\n",
       "    <th>0.224148</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>2.974938</th>\n",
       "    <th>2.456569</th>\n",
       "    <th>0.466051</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>2.316732</th>\n",
       "    <th>1.946969</th>\n",
       "    <th>0.465625</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>1.866151</th>\n",
       "    <th>1.991952</th>\n",
       "    <th>0.314702</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>5</th>\n",
       "    <th>1.618516</th>\n",
       "    <th>1.802403</th>\n",
       "    <th>0.437216</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>6</th>\n",
       "    <th>1.411517</th>\n",
       "    <th>1.731107</th>\n",
       "    <th>0.436293</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>7</th>\n",
       "    <th>1.171916</th>\n",
       "    <th>1.655979</th>\n",
       "    <th>0.504048</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>8</th>\n",
       "    <th>0.965887</th>\n",
       "    <th>1.579963</th>\n",
       "    <th>0.522088</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>9</th>\n",
       "    <th>0.797046</th>\n",
       "    <th>1.479819</th>\n",
       "    <th>0.565057</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>10</th>\n",
       "    <th>0.659378</th>\n",
       "    <th>1.487831</th>\n",
       "    <th>0.579048</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>11</th>\n",
       "    <th>0.553282</th>\n",
       "    <th>1.441922</th>\n",
       "    <th>0.597798</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>12</th>\n",
       "    <th>0.475167</th>\n",
       "    <th>1.498148</th>\n",
       "    <th>0.600781</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>13</th>\n",
       "    <th>0.416131</th>\n",
       "    <th>1.546984</th>\n",
       "    <th>0.606463</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>14</th>\n",
       "    <th>0.372395</th>\n",
       "    <th>1.594261</th>\n",
       "    <th>0.607386</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>15</th>\n",
       "    <th>0.337093</th>\n",
       "    <th>1.578321</th>\n",
       "    <th>0.613352</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>16</th>\n",
       "    <th>0.311385</th>\n",
       "    <th>1.580973</th>\n",
       "    <th>0.623366</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>17</th>\n",
       "    <th>0.292869</th>\n",
       "    <th>1.625745</th>\n",
       "    <th>0.618253</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>18</th>\n",
       "    <th>0.279486</th>\n",
       "    <th>1.623960</th>\n",
       "    <th>0.626065</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>19</th>\n",
       "    <th>0.270054</th>\n",
       "    <th>1.682090</th>\n",
       "    <th>0.611719</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>20</th>\n",
       "    <th>0.263857</th>\n",
       "    <th>1.675676</th>\n",
       "    <th>0.614702</th>\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(20, 3e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2-layer GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model5(nn.Module):\n",
    "    \"\"\"Embedding -> 2-layer `nn.GRU` -> `BatchNorm1dFlat` -> linear decoder.\n",
    "\n",
    "    The final hidden state of each call is kept (detached) in `self.h` and fed\n",
    "    to the next call. Sizes come from the notebook globals `nv`, `nh` and `bs`.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.i_h = nn.Embedding(nv,nh)\n",
    "        self.rnn = nn.GRU(nh, nh, 2, batch_first=True)\n",
    "        self.h_o = nn.Linear(nh,nv)\n",
    "        self.bn = BatchNorm1dFlat(nh)\n",
    "        # One state slice per GRU layer. No hard-coded `.cuda()`: `forward` moves the\n",
    "        # state to the input's device, so the model also works on a CPU-only machine.\n",
    "        self.h = torch.zeros(2, bs, nh)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run one batch of token ids `x` (batch dimension first) and return the decoder output.\"\"\"\n",
    "        h = self.h.to(x.device)  # no-op once the state already lives on that device\n",
    "        if h.size(1) != x.size(0):\n",
    "            # Batch size differs from the carried state (e.g. a short final batch):\n",
    "            # the old state cannot be reused, so restart from zeros.\n",
    "            h = h.new_zeros(h.size(0), x.size(0), h.size(2))\n",
    "        res,h = self.rnn(self.i_h(x), h)\n",
    "        # Detach so backprop does not reach into previous batches.\n",
    "        self.h = h.detach()\n",
    "        return self.h_o(self.bn(res))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(data, Model5(), metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 00:02 <p><table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>accuracy</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>2.864854</th>\n",
       "    <th>2.314943</th>\n",
       "    <th>0.454545</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>1.798988</th>\n",
       "    <th>1.357116</th>\n",
       "    <th>0.629688</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>0.932729</th>\n",
       "    <th>1.307463</th>\n",
       "    <th>0.796733</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>0.451969</th>\n",
       "    <th>1.329699</th>\n",
       "    <th>0.788636</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>5</th>\n",
       "    <th>0.225787</th>\n",
       "    <th>1.293570</th>\n",
       "    <th>0.800142</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>6</th>\n",
       "    <th>0.118085</th>\n",
       "    <th>1.265926</th>\n",
       "    <th>0.803338</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>7</th>\n",
       "    <th>0.065306</th>\n",
       "    <th>1.207096</th>\n",
       "    <th>0.806960</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>8</th>\n",
       "    <th>0.038098</th>\n",
       "    <th>1.205361</th>\n",
       "    <th>0.813920</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>9</th>\n",
       "    <th>0.024069</th>\n",
       "    <th>1.239411</th>\n",
       "    <th>0.807813</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>10</th>\n",
       "    <th>0.017078</th>\n",
       "    <th>1.253409</th>\n",
       "    <th>0.807102</th>\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(10, 1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## fin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
